import torch
import torch.nn as nn

class LogisticRegressionModel(nn.Module):
    """Linear model followed by a sigmoid (binary) or softmax (multi-class).

    With ``output_dim == 1`` the model is binary logistic regression and
    ``forward`` returns sigmoid probabilities. With ``output_dim > 1`` it is
    multinomial (softmax) regression and ``forward`` returns a probability
    distribution over the last dimension.

    Note: ``forward`` returns probabilities, not logits. For training, the
    raw logits from ``self.linear(x)`` paired with ``nn.BCEWithLogitsLoss`` /
    ``nn.CrossEntropyLoss`` are numerically more stable than applying
    ``BCELoss`` / ``NLLLoss(log(...))`` to these probabilities.
    """

    def __init__(self, input_dim: int, output_dim: int) -> None:
        """Create the model.

        Args:
            input_dim: Number of input features (size of the last input dim).
            output_dim: Number of outputs; 1 selects sigmoid, >1 softmax.
        """
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class probabilities for ``x``.

        Args:
            x: Tensor whose last dimension has size ``input_dim``. Any number
                of leading (batch) dimensions is accepted, including none.

        Returns:
            Tensor of shape ``(*x.shape[:-1], output_dim)`` with sigmoid
            probabilities (``output_dim == 1``) or softmax probabilities
            summing to 1 over the last dimension (``output_dim > 1``).
        """
        logits = self.linear(x)
        # nn.Linear maps the LAST dimension, so the class axis is always -1.
        # Deciding from the layer's configured width (rather than
        # logits.shape[1]) also works for unbatched 1-D and N-D inputs.
        if self.linear.out_features == 1:
            return torch.sigmoid(logits)
        return torch.softmax(logits, dim=-1)